// Disjoint-set structure over the integers 0..n-1.
// Sets are merged by size and lookups compress the path they walk.
class UnionFind {
public:
    // Builds n singleton sets: every element is its own representative
    // and every set has size 1.
    UnionFind(int n) : parent(n), weight(n, 1) {
        for (int i = 0; i < n; ++i) {
            parent[i] = i;
        }
    }

    // Merges the sets containing a and b.
    // Returns false when they already share a representative, true otherwise.
    bool unite(int a, int b) {
        const int rootA = find(a);
        const int rootB = find(b);
        if (rootA == rootB) {
            return false;
        }
        // The strictly larger set keeps its representative; on a tie the
        // representative of b's set is kept.
        const bool aIsLarger = weight[rootA] > weight[rootB];
        const int keeper = aIsLarger ? rootA : rootB;
        const int absorbed = aIsLarger ? rootB : rootA;
        parent[absorbed] = keeper;
        weight[keeper] += weight[absorbed];
        return true;
    }

    // Returns the representative of the set containing x and re-points every
    // element visited on the way directly at that representative.
    int find(int x) {
        // First pass: walk up to the representative.
        int root = x;
        while (parent[root] != root) {
            root = parent[root];
        }
        // Second pass: hang each element of the walked path under the root.
        while (parent[x] != root) {
            const int next = parent[x];
            parent[x] = root;
            x = next;
        }
        return root;
    }

    // Returns the stored size for `root`. The value is only kept up to date
    // for current representatives, so pass the result of find().
    int getSize(int root) {
        return weight[root];
    }

private:
    vector<int> parent;  // parent[i] == i marks a representative
    vector<int> weight;  // element count, maintained at representatives
};

class Solution {
public:
    int minMalwareSpread(vector<vector<int>>& graph, vector<int>& initial) {
        int n = graph.size();
        UnionFind uf(n);
        for (int i = 0; i < n; ++i) {
            for (int j = i + 1; j < n; ++j) {
                if (graph[i][j]) {
                    uf.unite(i, j);
                }
            }
        }
        int ans = n, mx = 0;
        vector<int> cnt(n);
        for (int x : initial) {
            ++cnt[uf.find(x)];
        }
        for (int x : initial) {
            int root = uf.find(x);
            if (cnt[root] == 1) {
                int sz = uf.getSize(root);
                if (sz > mx || (sz == mx && ans > x)) {
                    ans = x;
                    mx = sz;
                }
            }
        }
        return ans == n ? *min_element(initial.begin(), initial.end()) : ans;
    }
};